import numpy as np
# import pykalman
import matplotlib.pyplot as plt
import scipy.stats as stats
import ekf_testv6 as ekf
import summary_of_sims as sum_sims
import copy
import random
import matplotlib.transforms as mtransforms

# Simulated GMST series from the LENS2 summary module, one row per simulation.
twTR=sum_sims.twTR

# Simulated ocean heat content, one row per simulation. Read through a
# context manager so the file handle is closed once parsing finishes
# (previously the handle returned by open() was never closed).
with open("OHCcomb.csv", "rb") as ocn_file:
    ocn_sims = np.genfromtxt(ocn_file, dtype=float, delimiter=',')

# Number of simulations comes from twTR; the analysed series length comes
# from the EKF run length (ekf.n_iters), not from twTR's column count.
sms = np.shape(twTR)[0]
smyrs=ekf.n_iters
#dif=0.21
twTRmean=sum_sims.twTRmean
twTRstd=sum_sims.twTRstd
smdate=1850

dates=ekf.dates
temps=ekf.temps

# Colours for the simulated-state curves. A 2:1 weighted blend of
# ekf.colorekf and ekf.colorstate used to be computed here, but it was
# immediately overwritten by these fixed values, so that dead code is gone.
colorekfsims="orange"
# Base colour handed to sum_sims.colorgen for the per-simulation curves
# (presumably a 0-255 RGB triple -- confirm against colorgen).
colorekfsimsl=[215,100,61]



def KL_div(a_samps, bmeans, bstdevs,useStds=True): #a_samps is n x t, where t is the number of years
    """Score how far a standardized sample ensemble is from N(0, 1).

    Every series in ``a_samps`` is standardized with a reference mean and
    standard deviation built from ``bmeans``/``bstdevs``. The pooled values
    are histogrammed (density) on 400 equal bins over [-20, 20], and
    ``sum(p * log(p / q))`` over non-empty bins is returned. Here ``p`` is the
    histogram density and ``q`` is the standard normal pdf at the bin centres.

    NOTE(review): the sum uses the natural log and is not multiplied by the
    bin width (0.1). It is therefore about 10x a Riemann-sum KL estimate in
    nats, not a value in bits as the "Bits" axis labels suggest.
    Samples that standardize outside [-20, 20] are dropped from the histogram.

    Parameters
    ----------
    a_samps : array_like, shape (n, t)
        n sample series of length t (t = number of years).
    bmeans : array_like, shape (t,) or (k, t)
        Reference mean series. With k > 1 rows they are combined: the centre
        is the row mean, and the spread is either
        ``sqrt((mean(bstdevs**2) + (k-1) * var(bmeans, ddof=1)) / k)``
        (``useStds=True``) or the row-wise sample std, ddof=1
        (``useStds=False``).
    bstdevs : array_like, same shape as ``bmeans``
        Reference standard deviation series.
    useStds : bool
        Selects how the spread is combined for k > 1 (see ``bmeans``).

    Returns
    -------
    numpy.float64
        The divergence score; nan if no sample lands inside [-20, 20].
    """
    a_samps = np.asarray(a_samps, dtype=float)
    bmeans = np.asarray(bmeans, dtype=float)
    bstdevs = np.asarray(bstdevs, dtype=float)

    # Combine several reference runs into one idealized distribution.
    cmean = bmeans
    cstdev = bstdevs
    if bmeans.ndim > 1 and bmeans.shape[0] > 1:
        k = bmeans.shape[0]
        cmean = np.mean(bmeans, axis=0)
        if useStds:
            cstdev = np.sqrt((np.mean(np.square(bstdevs), axis=0)
                              + (k - 1) * np.var(bmeans, axis=0, ddof=1)) / k)
        else:
            cstdev = np.std(bmeans, axis=0, ddof=1)

    # Broadcasting subtracts the reference from every sample row without
    # materializing an np.tile copy.
    a_tormed = (a_samps - cmean) / cstdev
    edges = np.linspace(-20, 20, num=401, endpoint=True)
    ahist, bin_edges = np.histogram(a_tormed, edges, density=True)
    bideal = stats.norm.pdf((bin_edges[1:] + bin_edges[:-1]) / 2)
    # Evaluate the log only on non-empty bins. np.where computed log(0) for
    # empty bins first, which emitted divide/invalid RuntimeWarnings each call.
    nz = ahist != 0
    return np.sum(ahist[nz] * np.log(ahist[nz] / bideal[nz]))


def _mosaic_figure(size, layout, width_ratios, height_ratios, hspace, left, right):
    """Create a new figure and lay out its panels with ``subplot_mosaic``.

    Every figure in this script uses ``wspace=0.35``; the other gridspec
    settings are passed through. Returns ``(figure, {panel_letter: Axes})``.
    """
    new_fig = plt.figure(figsize=size)
    panels = new_fig.subplot_mosaic(
        layout,
        gridspec_kw={
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
            "wspace": 0.35,
            "hspace": hspace,
            "left": left,
            "right": right,
        },
    )
    return new_fig, panels


# fig / fig2: wide time-series panel "a" above panels "d" and "e".
# fig1b / fig2b (saved as the "_supp" files): 2x2 grid "b" "c" / "d" "e".
_WIDE_TOP_LAYOUT = "\naa\nde\n"
_GRID_LAYOUT = "\nbc\nde\n"

fig, ax_dict = _mosaic_figure((7, 7), _WIDE_TOP_LAYOUT, [2, 1], [6, 3], 0.35, 0.13, 0.9)
fig1b, ax_dict1b = _mosaic_figure((5, 5), _GRID_LAYOUT, [1, 1], [1, 1], 0.45, 0.1, 0.96)
fig2, ax_dict2 = _mosaic_figure((7, 7), _WIDE_TOP_LAYOUT, [2, 1], [6, 3], 0.35, 0.1, 0.9)
fig2b, ax_dict2b = _mosaic_figure((5, 5), _GRID_LAYOUT, [1, 1], [1, 1], 0.45, 0.1, 0.96)

# Real-data panel of the OHCA figure: posterior EBM-KF ocean-heat state from
# the Zanna (2019) run, scaled by ekf.zJ_from_W, plus a +/-2 std-dev band.
# The band starts at index 4 (NOTE(review): presumably skips filter spin-up).
# Labels are raw strings so the mathtext backslashes (\hat, \pm, \sqrt) are
# not parsed as Python escapes; the old non-raw literals relied on the
# invalid-escape fallback, which raises SyntaxWarning on Python 3.12+.
ax_dict2["a"].plot(ekf.dates,ekf.xh1d*ekf.zJ_from_W,'-',label=r'a posteriori EBM-KF-uf state $\hat{H}_t$ using Zanna (2019)', color=ekf.colorekf,zorder=4)
ax_dict2["a"].fill_between(dates[4:], (ekf.xh1d[4:]-2*ekf.stdPd[4:])*ekf.zJ_from_W, (ekf.xh1d[4:]+2*ekf.stdPd[4:])*ekf.zJ_from_W,label=r"95% CI ($\pm 2\sqrt{\hat{p}^H_t}$) of EBM-KF-uf OHCA state $\hat{H}_t$ using Zanna (2019)", color=ekf.colorstate,alpha=0.25,zorder=2)
ax_dict2["a"].set_xticks(np.arange(1850,2025+1,25))
ax_dict2["a"].set_xlim(1850,2025)
ax_dict2["a"].set_xlabel('Year')

# Real-data (HadCRUT5) panel of the GMST figure. plt.yticks acts on the
# current axes, hence the plt.sca first.
plt.sca(ax_dict["a"])
ekf.plot_boilerplate(ax_dict["a"])
plt.yticks(np.arange(286.2,288.4,0.2))
ax_dict["a"].set_ylim(286.3,288.1)
# Raw strings keep the mathtext backslashes literal (non-raw "\h", "\p",
# "\s" are invalid escapes and warn on Python 3.12+).
ax_dict["a"].plot(ekf.dates,ekf.xh1s,'-',label=r'posterior EBM-KF-uf GMST state $\hat{T}_t$ using real HadCRUT5', color=ekf.colorekf,zorder=4)
ax_dict["a"].fill_between(dates[4:], ekf.xh1s[4:]-2*ekf.stdP[4:], ekf.xh1s[4:]+2*ekf.stdP[4:],label=r"95% CI ($\pm 2\sqrt{\hat{p}^T_t}$) of EBM-KF-uf state $\hat{T}_t$ using real HadCRUT5", color=ekf.colorstate,alpha=0.25,zorder=2)
# Private copies of the real-data EKF results: GMST state/std dev (xh1ds,
# stdPs) and OHC state/std dev (xh1d, stdPd).
xh1ds=copy.deepcopy(ekf.xh1s)
stdPs=copy.deepcopy(ekf.stdP)
xh1d=copy.deepcopy(ekf.xh1d)
stdPd=copy.deepcopy(ekf.stdPd)
avglen=30  # not used further in this script
avglend2=0  # years trimmed from each end of the compared series (0 = none)
# Per-simulation results, rows = sims, cols = years. The fill values mark
# entries the loop below does not write; they are turned into NaN afterwards
# (0 for all_xhats, -4000 for all_xhatd, 350 for the qq arrays).
all_xhats=np.zeros([sms, smyrs])
all_xhatd=np.zeros([sms, smyrs])-4000
all_Ps=np.zeros([sms, smyrs])
all_Pd=np.zeros([sms, smyrs])
all_qqys=np.zeros([sms, smyrs])+350
all_qqyd=np.zeros([sms, smyrs])+350
# Run the EBM-KF on each LENS2 simulation (simulated GMST + OHC as the
# observations). Store its state estimates, state std devs, and deviations
# from the real-data run in units of the real-data state std dev.
for i in range(sms):
    cur_temps=twTR[i,:]
    # Usable length = up to the first NaN in the simulated temperatures,
    # capped at 174 years. np.isnan is required: the former
    # list(cur_temps).index(float("nan")) could never match (NaN != NaN),
    # so NaN padding was never detected.
    nan_positions = np.flatnonzero(np.isnan(cur_temps))
    len_sim = int(nan_positions[0]) if nan_positions.size else len(cur_temps)
    len_sim = min(len_sim, 174)
    observ = np.transpose(np.array([cur_temps[:len_sim],ocn_sims[i,:len_sim]/ekf.zJ_from_W])) ##NOTE - should I enlarge the LENS2 OHC by 5/3 to match Zanna obs?
    # NOTE(review): assumes ekf_run returns len_sim rows for len_sim inputs.
    sim_xavg, this_Ps=ekf.ekf_run(observ,len_sim, retPs=True)

    dateslice=dates[(sum_sims.smdate-sum_sims.sdate):(sum_sims.smdate-sum_sims.sdate+len_sim)]

    # Write only the first len_sim columns; a shorter run leaves the
    # sentinel fill in place, which is converted to NaN after the loop.
    all_xhats[i,:len_sim]=sim_xavg[:,0]
    all_xhatd[i,:len_sim]=sim_xavg[:,1]
    all_Ps[i,:len_sim]=np.sqrt(np.abs(this_Ps[:,0,0]))
    all_Pd[i,:len_sim]=np.sqrt(np.abs(this_Ps[:,1,1]))
    all_qqys[i,:len_sim]=(sim_xavg[:,0]-xh1ds[0+avglend2:len_sim-avglend2])/stdPs[0+avglend2:len_sim-avglend2]
    all_qqyd[i,:len_sim]=(sim_xavg[:,1]-xh1d[0+avglend2:len_sim-avglend2])/stdPd[0+avglend2:len_sim-avglend2]
    if i==(sms-1):
        # Only the last simulation gets a legend label.
        ax_dict["a"].plot(dateslice,sim_xavg[:,0],'-',label='posterior EBM-KF-uf GMST states using LENS2 sims', color=colorekfsims, linewidth=0.5,zorder=1)
        ax_dict2["a"].plot(dateslice,sim_xavg[:,1]*ekf.zJ_from_W,'-',label='posterior EBM-KF-uf OHCA states using LENS2 sims', color=colorekfsims, linewidth=0.5,zorder=1)
    else:
        # colorgen is called separately for each panel, as before, in case
        # it is not a pure function of its arguments.
        ax_dict["a"].plot(dateslice,sim_xavg[:,0],'-', color=sum_sims.colorgen(i,(70,30,30),colorekfsimsl), linewidth=0.5,zorder=1)
        ax_dict2["a"].plot(dateslice,sim_xavg[:,1]*ekf.zJ_from_W,'-', color=sum_sims.colorgen(i,(70,30,30),colorekfsimsl), linewidth=0.5,zorder=1)
# Replace the sentinel fill values (entries not written by the per-sim loop)
# with NaN so the nan-aware reductions below ignore them.
all2_xhats=np.where(all_xhats == 0, float("nan"), all_xhats)
all2_qqys=np.where(all_qqys == 350, float("nan"), all_qqys)
all2_xhatd=np.where(all_xhatd == -4000, float("nan"), all_xhatd)
all2_qqyd=np.where(all_qqyd == 350, float("nan"), all_qqyd)
# Ensemble means must use the NaN-masked arrays. np.nanmean on all_xhats /
# all_xhatd would average the 0 / -4000 sentinels in as real values.
mean_xhats=np.nanmean(all2_xhats,axis=0)
mean_xhatd=np.nanmean(all2_xhatd,axis=0)
# NOTE(review): despite its name, this is the spread of the raw simulated
# temperatures (twTR), not of the filtered states -- confirm intended.
std_xhats=np.nanstd(twTR,axis=0) #/np.sqrt(sms)

print(np.mean(std_xhats))
print(np.mean(stdPs))

ax_dict["a"].plot(ekf.dates[avglend2:smyrs-avglend2],mean_xhats,'-',label='mean simulated EBM-KF-uf GMST state', color='black',zorder=3)
ax_dict2["a"].plot(ekf.dates[avglend2:smyrs-avglend2],mean_xhatd*ekf.zJ_from_W,'-',label='mean simulated EBM-KF-uf OHCA state', color='black',zorder=3)

r2 = ekf.r2_score(mean_xhats, ekf.xh1s)
print('r2 score for GMST ens mean vs EBM-KF is', r2)
r2 = ekf.r2_score(mean_xhatd*ekf.zJ_from_W, ekf.xh1d*ekf.zJ_from_W)
print('r2 score for OHCA ens mean vs EBM-KF is', r2)

#plt.fill_between(dates, xh1ds-2*stdPs, xh1ds+2*stdPs,label="95% CI of EBM-KF state from measurements ($2\sqrt{P_{n}}$)", color='seagreen',alpha=0.2,zorder=2)
# Title goes on the current pyplot axes; this relies on the earlier
# plt.sca(ax_dict["a"]) still being in effect.
plt.title("Simulated EBM-KF-uf GMST States vs HadCRUT5 EBM-KF-uf State Uncertainty")
# Reorder legend entries. `order` indexes into the handles as returned by
# get_legend_handles_labels, so it needs at least 4 labeled artists per axes.
# NOTE(review): fragile if ekf.plot_boilerplate adds or removes labeled artists.
handles, labels = ax_dict["a"].get_legend_handles_labels()
order = [0,3,2,1]
ax_dict["a"].legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc='upper left')
handles, labels = ax_dict2["a"].get_legend_handles_labels()
ax_dict2["a"].legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc='upper left')

# Right-hand twin axis showing the same GMST range in °C (K - 273.15).
# The limits are copied once, so the main ylim must be final by this point.
mn, mx = ax_dict["a"].get_ylim()
ax1b = ax_dict["a"].twinx()
ax1b.set_yticks(np.arange(13.2,15.2,0.2))
ax1b.set_ylim(mn-273.15, mx-273.15)
ax1b.set_ylabel('Temperature (°C)')
#ax1b.yaxis.set_label_coords(0.95, 0.5)

# OHCA panel: twin axis rescales heat to thermosteric sea level via
# ekf.zJtomm / 10 (presumably zJ -> mm, then mm -> cm; confirm in ekf).
mn, mx = ax_dict2["a"].get_ylim()
ax_dict2["a"].set_ylabel('Heat (zJ)')
# Repeats the xlim/xlabel already set right after the real-data OHCA plot.
ax_dict2["a"].set_xlim([1850,2025])
ax_dict2["a"].set_xlabel("Year")
ax_dict2["a"].set_title("Simulated EBM-KF-uf OHCA States vs Zanna EBM-KF-uf State Uncertainty")
ax1b = ax_dict2["a"].twinx()
ax1b.set_yticks(np.arange(-1,8,1))
ax1b.set_ylim(mn*ekf.zJtomm/10, mx*ekf.zJtomm/10)
ax1b.set_ylabel('Thermosteric Sea Level Rise (cm)')
#ax1b.yaxis.set_label_coords(0.95, 0.3)

lpad=0 # x-label padding used for the small supplementary panels
# Supplementary GMST figure: panel "c" = QQ plot and panel "b" = histogram
# of all2_qqys, the deviations of each simulation's EKF GMST state from the
# real-data state in units of the real-data state std dev (NaNs dropped).
ax1=ax_dict1b["b"]
ax2=ax_dict1b["c"]

# probplot is given plot=plt, so it draws on the current pyplot axes.
plt.sca(ax2)
stats.probplot(all2_qqys[~np.isnan(all2_qqys)],dist="norm", plot=plt)
# Restyle the two lines probplot drew: [0] = sample points, [1] = fit line.
ax2.get_lines()[0].set_markerfacecolor(ekf.colorstate)
ax2.get_lines()[0].set_markeredgewidth(0)
ax2.get_lines()[1].set_color(ekf.pcolor)
ax2.get_lines()[1].set_linewidth(1.5)
ax2.get_lines()[0].set_markersize(3.0)
ax2.set_xlabel("Theoretical Quantiles",labelpad=lpad)
ax2.set_title("Real-Sims QQ GMST")

#fig.suptitle("Variance of Simulated EBM-KF States $\\neq$ State Variance ($P$) of Measurements")

# Density histogram of the same values, overlaid with the standard normal pdf.
plt.sca(ax1)
ax1.hist(all2_qqys[~np.isnan(all2_qqys)],bins=sum_sims.nbins, density=True,color=ekf.colorstate)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax1.plot(xnorm,ynorm, color=ekf.pcolor)
ax1.set_xlim([-6,6])
ax1.set_xlabel("State Std Dev",labelpad=lpad)
ax1.set_ylabel("Probability Density")
ax1.set_title("Real-Sims Dist. GMST")


# Sim-vs-sim comparison (GMST). Take each simulation i in turn as the
# reference and express every other simulation j as a deviation from it:
# raw state difference (allabsdfs), and the difference in units of sim i's
# own EKF state std dev all_Ps[i] (allall_qqys).
# Row i of the (sms, sms*smyrs) arrays holds sim j's series at columns
# j*smyrs:(j+1)*smyrs; the i == j block is NaN. NaN is written directly,
# instead of the former magic fill value 50, which could collide with a
# genuine value. The pair differences are built by broadcasting rather than
# a double Python loop.
pair_diffs = all_xhats[:, np.newaxis, :] - all_xhats[np.newaxis, :, :]
pair_diffs[np.arange(sms), np.arange(sms), :] = np.nan
allabsdfs = pair_diffs.reshape(sms, sms * smyrs)
allall_qqys = (pair_diffs / all_Ps[:, np.newaxis, :]).reshape(sms, sms * smyrs)
# Already NaN-masked; the "2" names are the ones used below.
allall2_qqys = allall_qqys
all2absdfs = allabsdfs

# Per-reference-sim summaries over all other sims and years.
mean_qqys = np.nanmean(allall2_qqys, axis=1)
meandfs = np.nanmean(all2absdfs, axis=1)
stdev_qqys = np.nanstd(allall2_qqys, axis=1)

ax4=ax_dict["d"]
# NOTE(review): sim 0 is left out of this scatter ([1:]) -- confirm intended.
ax4.plot(mean_qqys[1:],stdev_qqys[1:],'.',color=colorekfsims,label="Sim-Sims comparisons")
ax4.plot(np.nanmean(allall2_qqys),np.nanstd(allall2_qqys),'x',color='black',label="Centroid Sim-Sims compars",markersize=10)
ax4.plot(np.nanmean(all2_qqys),np.nanstd(all2_qqys),'*',markeredgewidth=1, markeredgecolor='k',markersize=10, color=ekf.colorstate,label="Real-Sims comparison")
ax4.set_xlabel("LENS2 States Ensemble Mean Bias in Std Devs")
ax4.set_ylabel("LENS2 States Ensemble Std Dev \n Relative to Prediction")
#ax4.yaxis.set_label_coords(-.07, .5)
# Reference lines: zero bias and unit std dev.
ax4.plot([0,0],[0,3],'-',color=ekf.pcolor,zorder=0)
ax4.plot([-2,3],[1,1],'-',color=ekf.pcolor,zorder=0)
ax4.set_ylim([0.4,1.8])
ax4.set_xlim([-1.1,1.1])

ax4.set_title("Distribution of Simulated EBM-KF-uf States \n relative to Prediction from a Single Run")
print(np.mean(abs(mean_qqys)),np.max(mean_qqys),np.min(mean_qqys))
print(np.mean(abs(meandfs)),np.max(meandfs),np.min(meandfs),'K')

axKL=ax_dict["e"]
axKL.set_title("Kullback–Leibler Divergence from  \n Simulated EBM-KF-uf State Ensemble \n to Predicted Distribution")
# Star at x=0: the whole simulated ensemble scored against the distribution
# from the EBM-KF run on the real data (xh1ds, stdPs).
axKL.plot(0,KL_div(all_xhats,xh1ds, stdPs),'*',markeredgewidth=1, markeredgecolor='k',color=ekf.colorstate, markersize=10)
KL_1s=np.empty([sms])
KL_2s=np.empty([sms*(sms-1)])
nsKL = int(sms*(sms-1)/2)
KL_3s=np.empty([nsKL]); KL_3p=np.empty([nsKL])
KL_8s=np.empty([nsKL]); KL_8p=np.empty([nsKL])
# Ensemble scored against each single simulation used as the reference.
for i in range(sms):
    KL_1s[i]=KL_div(all_xhats,all_xhats[i], all_Ps[i])
# "Good" = the sim at about the 25th percentile of single-sim KL;
# "worst" = the sim with the largest single-sim KL.
gidx = int(0.25 * (sms - 1)+.5)
goodKL = np.argpartition(KL_1s, gidx)[gidx]
worstKL=np.argmax(KL_1s)
ax4.plot(mean_qqys[goodKL],stdev_qqys[goodKL],'.',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',
         label="Good, KL="+format(KL_1s[goodKL], ".2f"),markersize=10)
ax4.plot(mean_qqys[worstKL],stdev_qqys[worstKL],'s',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',
         label="Worst, KL="+format(KL_1s[worstKL], ".2f"),markersize=5)
ax4.legend(loc="best",fontsize="8",ncol=2)
# Every ordered pair of *distinct* sims (i, j). The former inner loop
# `for j in range(sms-1)` paired each sim with itself and never used the
# last sim as the second member.
for i in range(sms):
    others = [j for j in range(sms) if j != i]
    for k, j in enumerate(others):
        KL_2s[k*sms+i]=KL_div(all_xhats,[all_xhats[i],all_xhats[j]], [all_Ps[i],all_Ps[j]])
# Random subsets of 3 and 8 sims: combined spread using the EKF std devs as
# a prior on the variance (KL_3s/KL_8s) or the plain sample std (KL_3p/KL_8p).
# Unseeded, so these violins vary from run to run.
for i in range(nsKL):
    a3=random.sample(range(0, sms), 3)
    KL_3s[i]=KL_div(all_xhats,all_xhats[a3], all_Ps[a3])
    KL_3p[i]=KL_div(all_xhats,all_xhats[a3], all_Ps[a3],False)
    a8=random.sample(range(0, sms), 8)
    KL_8s[i]=KL_div(all_xhats,all_xhats[a8], all_Ps[a8])
    KL_8p[i]=KL_div(all_xhats,all_xhats[a8], all_Ps[a8],False)

KL_all=[KL_1s,KL_2s,KL_3s, KL_8s, KL_3p,KL_8p]
vplot=axKL.violinplot(KL_all, positions=[1,2,3,4,5,6],showmeans=True,showextrema=False)
labelsKL=[1,2,3,8,3,8]  # number of sims per violin
for i in range(6):
    axKL.text(i+1,np.mean(KL_all[i])*0.5,str(labelsKL[i]),ha='center')
    patch=vplot['bodies'][i]
    if i<4:
        patch.set_color(colorekfsims)
    else:
        patch.set_color("red")

axKL.text(0,30,"# of Sims")
axKL.set_yscale('log')
axKL.set_ylabel('Bits')
axKL.set_xticks([2.5, 5.5])
# Fixed mathtext: the stray "}" is removed, and the backslash is escaped so
# "\n" stays a real newline while "\hat" reaches mathtext intact.
axKL.set_xticklabels(['with $\\hat{p}^T_t$ as \n prior on var.', 'sample \n variance'])

axlabels=['a)','b)','c)','d)','e)']
axids=['a','d','e']
# Letter each panel of `fig`, placing the label 25 pt left of and 5 pt above
# the top-left corner of its axes.
for panel_label, panel_key in zip(axlabels, axids):
    panel_ax = ax_dict[panel_key]
    offset = mtransforms.ScaledTranslation(-25/72, 5/72, fig.dpi_scale_trans)
    panel_ax.text(0.0, 1.0, panel_label, transform=panel_ax.transAxes + offset,
                  fontsize='large', va='bottom')

# Histograms of the sim-vs-sim normalized deviations for the "good" (panel e)
# and "worst" (panel d) reference simulations, each overlaid with the
# standard normal pdf.
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
for ax3, ref_idx, panel_title in ((ax_dict1b["e"], goodKL, "Good Sim-Sims Dist."),
                                  (ax_dict1b["d"], worstKL, "Worst Sim-Sims Dist.")):
    plt.sca(ax3)
    ref_row = allall2_qqys[ref_idx, :]
    ax3.hist(ref_row[~np.isnan(ref_row)], bins=sum_sims.nbins, density=True, color=colorekfsims)
    ax3.plot(xnorm, ynorm, color=ekf.pcolor)
    ax3.set_xlim([-6, 6])
    ax3.set_xlabel("State Std Dev", labelpad=lpad)
    ax3.set_ylabel("Probability Density")
    ax3.set_title(panel_title)

axids2=['b','c','d','e']
# Letter each panel of fig1b, placing the label 10 pt left of and 5 pt above
# the top-left corner of its axes. The offset must use fig1b's own
# dpi_scale_trans. The former fig.dpi_scale_trans follows a different
# figure's dpi, so the offset came out wrong whenever fig1b was rendered at
# another dpi (e.g. the 400-dpi PNG or the PDF export).
for label, key in zip(axlabels, axids2):
    ax = ax_dict1b[key]
    trans = mtransforms.ScaledTranslation(-10/72, 5/72, fig1b.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')


# Diagnostics: overall spread of the GMST sim-vs-sim deviations and the mean
# EKF GMST state std dev across sims.
print(np.nanstd(allall2_qqys))
print(np.nanmean(all_Ps))

# Supplementary OHCA figure: panel "c" = QQ plot and panel "b" = histogram
# of all2_qqyd, the deviations of each simulation's EKF OHC state from the
# real-data state in units of the real-data state std dev (NaNs dropped).
ax1=ax_dict2b["b"]
ax2=ax_dict2b["c"]

# probplot is given plot=plt, so it draws on the current pyplot axes.
plt.sca(ax2)
stats.probplot(all2_qqyd[~np.isnan(all2_qqyd)],dist="norm", plot=plt)
# Restyle the two lines probplot drew: [0] = sample points, [1] = fit line.
ax2.get_lines()[0].set_markerfacecolor(ekf.colorstate)
ax2.get_lines()[0].set_markeredgewidth(0)
ax2.get_lines()[1].set_color(ekf.pcolor)
ax2.get_lines()[1].set_linewidth(1.5)
ax2.get_lines()[0].set_markersize(3.0)
ax2.set_xlabel("Theoretical Quantiles",labelpad=lpad)
ax2.set_title("Real-Sims QQ OHCA")

#fig.suptitle("Variance of Simulated EBM-KF States $\\neq$ State Variance ($P$) of Measurements")

# Density histogram of the same values, overlaid with the standard normal
# pdf. The x-range is wider than for GMST ([-15, 10]).
plt.sca(ax1)
ax1.hist(all2_qqyd[~np.isnan(all2_qqyd)],bins=sum_sims.nbins, density=True,color=ekf.colorstate)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax1.plot(xnorm,ynorm, color=ekf.pcolor)
ax1.set_xlim([-15,10])
ax1.set_xlabel("State Std Dev from $P_t$",labelpad=lpad)
ax1.set_ylabel("Probability Density")
ax1.set_title("Real-Sims Dist. OHCA")

# Sim-vs-sim comparison (OHC). Take each simulation i in turn as the
# reference and express every other simulation j as a deviation from it:
# raw state difference (allabsdfd), and the difference in units of sim i's
# own EKF state std dev all_Pd[i] (allall_qqyd).
# Row i of the (sms, sms*smyrs) arrays holds sim j's series at columns
# j*smyrs:(j+1)*smyrs; the i == j block is NaN. NaN is written directly,
# replacing the former magic fill values 350/50, which were inconsistent
# between the two arrays and could collide with genuine values.
pair_diffd = all_xhatd[:, np.newaxis, :] - all_xhatd[np.newaxis, :, :]
pair_diffd[np.arange(sms), np.arange(sms), :] = np.nan
allabsdfd = pair_diffd.reshape(sms, sms * smyrs)
allall_qqyd = (pair_diffd / all_Pd[:, np.newaxis, :]).reshape(sms, sms * smyrs)
# Already NaN-masked; the "2" names are the ones used below.
allall2_qqyd = allall_qqyd
all2absdfd = allabsdfd

# Per-reference-sim summaries over all other sims and years.
mean_qqyd = np.nanmean(allall2_qqyd, axis=1)
stdev_qqyd = np.nanstd(allall2_qqyd, axis=1)
meandfd = np.nanmean(all2absdfd, axis=1)


ax4=ax_dict2["d"]
# NOTE(review): sim 0 is left out of this scatter ([1:]) -- confirm intended.
ax4.plot(mean_qqyd[1:],stdev_qqyd[1:],'.',color=colorekfsims,label="Sim-Sims comparisons")

ax4.plot(np.nanmean(allall2_qqyd),np.nanstd(allall2_qqyd),'x',color='black',label="Centroid Sim-Sims compars",markersize=10)
ax4.plot(np.nanmean(all2_qqyd),np.nanstd(all2_qqyd),'*',markeredgewidth=1, markeredgecolor='k',color=ekf.colorstate,label="Real-Sims comparison",markersize=10)
ax4.set_xlabel("LENS2 States Ensemble Mean Bias in Std Devs")
ax4.set_ylabel("LENS2 States Ensemble Std Dev \n Relative to Prediction")
#ax4.yaxis.set_label_coords(-.07, .5)
# Reference lines: zero bias and unit std dev.
ax4.plot([0,0],[0,6],'-',color=ekf.pcolor,zorder=0)
ax4.plot([-3,3],[1,1],'-',color=ekf.pcolor,zorder=0)
ax4.set_ylim([0.4,5])
ax4.set_xlim([-2.7,2.7])
ax4.set_title("Distribution of Simulated EBM-KF-uf States \n relative to Prediction from a Single Run")
print(np.mean(abs(mean_qqyd)),np.max(mean_qqyd),np.min(mean_qqyd))
print(np.mean(abs(meandfd))*ekf.zJ_from_W,np.max(meandfd)*ekf.zJ_from_W,np.min(meandfd)*ekf.zJ_from_W,'ZJ')

axKL=ax_dict2["e"]
axKL.set_title("Kullback–Leibler Divergence from  \n Simulated EBM-KF-uf State Ensemble \n to Pred. Distn.")
# Star at x=0: the whole simulated OHC ensemble scored against the
# distribution from the EBM-KF run on the real data (xh1d, stdPd).
axKL.plot(0,KL_div(all_xhatd,xh1d, stdPd),'*',markeredgewidth=1, markeredgecolor='k',color=ekf.colorstate, markersize=10)
KL_1d=np.empty([sms])
KL_2d=np.empty([sms*(sms-1)])
nsKL = int(sms*(sms-1)/2)
KL_3d=np.empty([nsKL]); KL_3pd=np.empty([nsKL])
KL_8d=np.empty([nsKL]); KL_8pd=np.empty([nsKL])
# Ensemble scored against each single simulation used as the reference.
for i in range(sms):
    KL_1d[i]=KL_div(all_xhatd,all_xhatd[i], all_Pd[i])
# "Good" = the sim at about the 25th percentile of single-sim KL;
# "worst" = the sim with the largest single-sim KL. (The former extra
# `worstKL=np.argmax(KL_1d)` is gone; it overwrote the GMST worstKL global.)
gidx = int(0.25 * (sms - 1)+.5)
goodKLd = np.argpartition(KL_1d, gidx)[gidx]
worstKLd=np.argmax(KL_1d)
ax4.plot(mean_qqyd[goodKLd],stdev_qqyd[goodKLd],'.',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',
         label="Good, KL="+format(KL_1d[goodKLd], ".2f"),markersize=10)
ax4.plot(mean_qqyd[worstKLd],stdev_qqyd[worstKLd],'s',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',
         label="Worst, KL="+format(KL_1d[worstKLd], ".2f"),markersize=5)
ax4.legend(loc=[0.05,.5],fontsize="8",ncol=2)
# Every ordered pair of *distinct* sims (i, j). The former inner loop
# `for j in range(sms-1)` paired each sim with itself and never used the
# last sim as the second member.
for i in range(sms):
    others = [j for j in range(sms) if j != i]
    for k, j in enumerate(others):
        KL_2d[k*sms+i]=KL_div(all_xhatd,[all_xhatd[i],all_xhatd[j]], [all_Pd[i],all_Pd[j]])
# Random subsets of 3 and 8 sims, with EKF std devs as a prior on the
# variance (KL_3d/KL_8d) or with the plain sample std (KL_3pd/KL_8pd).
# Unseeded, so these violins vary from run to run.
for i in range(nsKL):
    a3=random.sample(range(0, sms), 3)
    KL_3d[i]=KL_div(all_xhatd,all_xhatd[a3], all_Pd[a3])
    KL_3pd[i]=KL_div(all_xhatd,all_xhatd[a3], all_Pd[a3],False)
    a8=random.sample(range(0, sms), 8)
    KL_8d[i]=KL_div(all_xhatd,all_xhatd[a8], all_Pd[a8])
    KL_8pd[i]=KL_div(all_xhatd,all_xhatd[a8], all_Pd[a8],False)

KL_alld=[KL_1d,KL_2d,KL_3d, KL_8d, KL_3pd,KL_8pd]
vplot=axKL.violinplot(KL_alld, positions=[1,2,3,4,5,6],showmeans=True,showextrema=False)
for i in range(6):
    axKL.text(i+1,np.mean(KL_alld[i])*0.5,str(labelsKL[i]),ha='center')
    patch=vplot['bodies'][i]
    if i<4:
        patch.set_color(colorekfsims)
    else:
        patch.set_color("red")

axKL.text(0,30,"# of Sims")
axKL.set_yscale('log')
axKL.set_ylabel('Bits')
axKL.set_xticks([2.5, 5.5])
# Fixed mathtext: the stray "}" is removed, and the backslash is escaped so
# "\n" stays a real newline while "\hat" reaches mathtext intact.
axKL.set_xticklabels(['with $\\hat{p}^H_t$ as \n prior on var.', 'sample \n variance'])



# Letter each panel of fig2, placing the label 25 pt left of and 5 pt above
# the top-left corner of its axes. The offset must use fig2's own
# dpi_scale_trans. The former fig.dpi_scale_trans follows a different
# figure's dpi, which is why the offset differed between the PDF and PNG
# exports of fig2.
for label, key in zip(axlabels, axids):
    ax = ax_dict2[key]
    trans = mtransforms.ScaledTranslation(-25/72, 5/72, fig2.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
        fontsize='large', va='bottom')




# Histograms of the OHC sim-vs-sim normalized deviations for the "good"
# (panel e) and "worst" (panel d) reference simulations, each overlaid with
# the standard normal pdf. The worst panel gets a wider x-range.
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
for ax3, ref_idx, panel_xlim, panel_title in (
        (ax_dict2b["e"], goodKLd, [-6, 6], "Good Sim-Sims Dist."),
        (ax_dict2b["d"], worstKLd, [-6, 10], "Worst Sim-Sims Dist.")):
    plt.sca(ax3)
    ref_row = allall2_qqyd[ref_idx, :]
    ax3.hist(ref_row[~np.isnan(ref_row)], bins=sum_sims.nbins, density=True, color=colorekfsims)
    ax3.plot(xnorm, ynorm, color=ekf.pcolor)
    ax3.set_xlim(panel_xlim)
    ax3.set_xlabel("State Std Dev", labelpad=lpad)
    ax3.set_ylabel("Probability Density")
    ax3.set_title(panel_title)


# Letter each panel of fig2b, placing the label 10 pt left of and 5 pt above
# the top-left corner of its axes. The offset must use fig2b's own
# dpi_scale_trans; the former fig.dpi_scale_trans follows a different
# figure's dpi and gave wrong offsets in the PDF/PNG exports of fig2b.
for label, key in zip(axlabels, axids2):
    ax = ax_dict2b[key]
    trans = mtransforms.ScaledTranslation(-10/72, 5/72, fig2b.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')



# Export every figure as a PDF and a 400-dpi PNG (same order as before),
# then open the interactive windows.
_figure_outputs = (
    (fig, "ekf_on_sims.pdf", "ekf_on_sims.png"),
    (fig1b, "ekf_on_sims_supp.pdf", "ekf_on_sims_supp.png"),
    (fig2, "ekf_on_sims_OHCA.pdf", "ekf_on_sims_OHCA.png"),
    (fig2b, "ekf_on_sims_supp_OHCA.pdf", "ekf_on_sims_supp_OHCA.png"),
)
for out_fig, pdf_path, png_path in _figure_outputs:
    out_fig.savefig(pdf_path, format="pdf")
    out_fig.savefig(png_path, dpi=400, format="png")

plt.show()

